import numpy as np
import torch

"""
this file contains miscellaneous functions used in `methods.py`
"""

# objective function
def obj(y, Phi, variance, alpha, w, supp):
    """
    Evaluate the objective
        f(w) = ||Phi * w - y||^2 + alpha * (1 - sum(variance[supp]))
    The constant `+ alpha` offset does not move the minimizer; it is only there
    for better illustration of the loss.

    :param y: torch tensor with M elements (documented shape (M, 1))
    :param Phi: torch tensor of shape (M, N)
    :param variance: torch tensor of shape (N,)
    :param alpha: positive float value
    :param w: torch tensor with N elements (documented shape (N, 1))
    :param supp: list. support of w (indices of non-zero elements).
    :return: tuple of python floats (f1 + f2, f1, f2)
    """
    # Force both operands to column vectors: subtracting a 1-D (or row) `y`
    # from an (M, 1) product would silently broadcast to (M, M) and inflate f1.
    residual = Phi.mm(w.reshape([-1, 1])) - y.reshape([-1, 1])
    f1 = (torch.norm(residual) ** 2).item()
    # +1 is only for better illustration of the loss
    f2 = (alpha * (- torch.sum(variance[supp]) + 1)).item()
    return f1 + f2, f1, f2

def square_loss_gradient(Phi, y, w):
    """
    Gradient of the squared loss f1(w) = ||Phi * w - y||^2 with respect to w,
    i.e. 2 * Phi^T * (Phi * w - y).

    :param Phi: torch tensor of shape (M, N)
    :param y: torch tensor of shape (M, 1)
    :param w: torch tensor of shape (N, 1)
    :return: torch tensor of shape (N, 1)
    """
    residual = Phi.mm(w) - y
    back_projected = Phi.T.mm(residual)
    return 2 * back_projected

def regularizer_gradient(w, supp, beta, reg_type):
    """
    The gradient of the regularizer restricted to the support `supp`:
        reg_type == 'mean':  beta / n * ||w - mean(w)||^2
        reg_type == 'one':   beta / n * ||w - 1||^2
    where n is the number of entries in the support.

    :param w: torch tensor whose first dimension indexes the coefficients
    :param supp: the string 'full' (use every entry of w) or an index
                 collection accepted by `w[supp]` and `len(supp)`
    :param beta: regularization weight
    :param reg_type: 'mean' or 'one'
    :return: 2 * beta * (w_s - target) / n, where w_s is `w` for 'full' and
             `w[supp]` otherwise (so only the support rows are returned)
    :raises ValueError: if `reg_type` is neither 'mean' nor 'one'
    """
    if reg_type not in ('mean', 'one'):
        raise ValueError(
            "unknown reg_type {!r}; expected 'mean' or 'one'".format(reg_type)
        )

    # Only compare against 'full' when supp really is a string, so that an
    # array-like support is never compared elementwise with a str.
    use_full = isinstance(supp, str) and supp == 'full'
    if use_full:
        active = w
        count = w.shape[0]
    else:
        active = w[supp]
        count = len(supp)

    # 'mean' pulls entries towards their own average, 'one' towards 1.
    target = active.mean() if reg_type == 'mean' else 1
    return 2 * beta * (active - target) / count

def projection_to_sparse_nonnegative(w, k, already_k_sparse=False):
    """
    Project w onto the set of k-sparse, non-negative vectors.

    NOTE: `w` is modified in place and the same tensor object is returned.

    :param w: torch tensor whose first dimension indexes the coefficients
    :param k: maximum number of non-zero entries to keep
    :param already_k_sparse: if True, skip the top-k selection and only clip
                             negative entries to zero
    :return: (x, supp) with x = argmin_x ||w-x||_2  s.t.  x is k-sparse and
             non-negative, and supp the list of indices of non-zero entries
    """
    if not already_k_sparse:
        N = w.shape[0]
        # Clamp at 0: with k > N a negative slice end would wrap around and
        # wrongly zero out the smallest entries instead of dropping nothing.
        num_dropped = max(N - k, 0)
        if num_dropped > 0:
            # ascending sort, so the first entries are the smallest values
            indices = w.reshape([-1]).argsort()
            w[indices[:num_dropped]] = 0.
    w[w < 0] = 0.
    supp = torch.nonzero(w.reshape([-1])).reshape([-1]).tolist()
    return w, supp

def select_high_variance_data(Phi, ratio, variance):
    """
    Keep only the top-[ratio] fraction of columns of Phi, ranked by variance.

    :param Phi: torch tensor of shape (M, N)
    :param ratio: fraction of the N columns to keep; int(N * ratio) are kept
    :param variance: torch tensor of shape (N,), one value per column of Phi
    :return: (reduced Phi, reduced variance, list of kept column indices in
             descending-variance order, original number of columns N)
    """
    _, num_columns = Phi.shape
    keep_count = int(num_columns * ratio)
    ranked = variance.argsort(descending=True)
    high_var_supp = ranked[:keep_count].tolist()
    return Phi[:, high_var_supp], variance[high_var_supp], high_var_supp, num_columns

def step_line_search(Phi, y, w, v):
    """
    Exact line search for the step size
        argmin_mu  ||y - Phi(w - mu*v)||_2^2

    With a = Phi*v and b = y - Phi*w the minimizer is mu = -<b, a> / <a, a>.

    :param Phi: torch tensor of shape (M, N)
    :param y: torch tensor of shape (M, 1)
    :param w: torch tensor of shape (N, 1), current iterate
    :param v: torch tensor of shape (N, 1), update direction
    :return: python float, the optimal step size mu
    """
    a = Phi.mm(v)
    b = y - Phi.mm(w)
    denominator = a.T.mm(a).item()
    if denominator == 0:
        # Phi*v vanishes, so the objective does not depend on mu; return a
        # zero step instead of raising ZeroDivisionError.
        return 0.0
    return - b.T.mm(a).item() / denominator

def step_line_search_reg(Phi, y, w, v, k, beta, reg_type):
    """
    Exact line search for the step size of the regularized objective
    if reg_type is 'mean':
        argmin_mu  ||y - Phi(w - mu*v)||_2^2  +  beta / N * ||(w - mu*v) - mean(w - mu*v)||_2^2
    if reg_type is 'one':
        argmin_mu  ||y - Phi(w - mu*v)||_2^2  +  beta / N * ||(w - mu*v) - 1||_2^2

    :param Phi: torch tensor of shape (M, N)
    :param y: torch tensor of shape (M, 1)
    :param w: torch tensor of shape (N, 1), current iterate
    :param v: torch tensor of shape (N, 1), update direction
    :param k: unused; kept so existing callers keep working
    :param beta: regularization weight
    :param reg_type: 'mean' or 'one'
    :return: python float, the optimal step size mu
    :raises ValueError: if `reg_type` is neither 'mean' nor 'one'
    """
    if reg_type not in ('mean', 'one'):
        # Fail loudly instead of silently returning None for a bad reg_type.
        raise ValueError(
            "unknown reg_type {!r}; expected 'mean' or 'one'".format(reg_type)
        )

    N = Phi.shape[1]
    a = Phi.mm(v)
    b = y - Phi.mm(w)

    if reg_type == 'mean':
        # The mean-centering operator is linear, so center w and v separately.
        w_shift = w - w.sum() / N
        v_shift = v - v.sum() / N
    else:
        # 'one': the target is the constant vector 1 and does not depend on mu.
        w_shift = w - 1
        v_shift = v

    numerator = -b.T.mm(a) + beta / N * w_shift.T.mm(v_shift)
    # 1e-30 keeps the division defined when the direction vanishes
    # (both branches now share this guard).
    denominator = a.T.mm(a) + beta / N * v_shift.T.mm(v_shift) + 1e-30
    return numerator.item() / denominator.item()

"""
early stop controller
"""

class early_stop_controller:
    """
    a controller for early stopping iterative descent methods

    A call to `early_stop` counts as "no improvement" when the new validation
    loss is not strictly below the minimum of the previous `probation` losses.
    After `probation` consecutive such calls the controller signals a stop.
    """
    def __init__(self, probation=20):
        """
        :param probation: size of the sliding loss window and number of
                          consecutive non-improving calls tolerated
        :raises ValueError: if `probation` is smaller than 1
        """
        if probation < 1:
            # An empty window would only fail later inside min() with an
            # unrelated-looking error message.
            raise ValueError("probation must be at least 1, got {!r}".format(probation))
        # lowest validation loss recorded so far
        self.min_loss = np.inf
        # solution object that achieved `min_loss`
        self.best_solution = None
        # number of consecutive non-improving calls
        self.counter = 0
        self.probation = probation
        # the most recent `probation` losses, oldest first
        self.loss_window = [np.inf] * probation

    def early_stop(self, solution, validation_loss):
        """
        Record a new validation loss and decide whether to stop.

        caution: the saved solution is conducted via '='.
                may need to consider passing a copy of the solution

        :param solution: the candidate solution associated with the loss
        :param validation_loss: validation loss of `solution`
        :return: True if the iteration should stop, False otherwise
        """
        # Compare against the window *before* the new loss is pushed into it.
        local_min_loss = min(self.loss_window)
        self.loss_window.pop(0)
        self.loss_window.append(validation_loss)

        if validation_loss >= local_min_loss:
            self.counter += 1
            # always return a bool (the in-probation case used to yield None)
            return self.probation <= self.counter

        self.counter = 0
        if validation_loss < self.min_loss:
            self.min_loss = validation_loss
            self.best_solution = solution
        return False

    def get_best_solution(self):
        """
        :return: the solution stored with the lowest recorded loss, or None if
                 nothing has been recorded yet
        """
        return self.best_solution